/*
 * Copyright (c) 2017-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/*! \file
    \brief Template for a pipelined mixed-input GEMM kernel (fp activations, int weights).
    Supports serial split-K reduction via the SplitKSerial template parameter; does not compute batching.
*/

#pragma once

#include "cutlass/cutlass.h"

#include "cutlass/arch/arch.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/matrix_coord.h"
#include "cutlass/semaphore.h"

#include <type_traits>

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass
{
namespace gemm
{
namespace kernel
{

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace detail
{
// Always-false constant that depends on a template parameter. Using it in a
// static_assert defers the failure until the enclosing template is actually
// instantiated (a plain `static_assert(false, ...)` would fire unconditionally).
template <typename T>
inline constexpr bool dependent_false_v = false;
} // namespace detail

/// Kernel-level GEMM for mixed-precision A (e.g. fp16/bf16) times quantized-int B,
/// dequantizing B with per-column (or fine-grained per-group) scales and optional zeros.
/// Optionally performs serial split-K reduction guarded by a per-tile semaphore.
template <typename Mma_,          ///! Threadblock-scoped matrix multiply-accumulate
    typename Epilogue_,           ///! Epilogue
    typename ThreadblockSwizzle_, ///! Threadblock swizzling function
    typename KernelArch, ///! The Architecture this kernel is compiled for. Used since SIMT kernels lose top-level
                         /// arch.
    bool SplitKSerial    ///! If true, code supporting split-K via serial reduction is enabled.
    >
struct GemmFpAIntB
{

    using Mma = Mma_;
    using Epilogue = Epilogue_;
    using EpilogueOutputOp = typename Epilogue::OutputOp;
    using ThreadblockSwizzle = ThreadblockSwizzle_;
    static bool const kSplitKSerial = SplitKSerial;

    using ElementA = typename Mma::IteratorA::Element;
    using LayoutA = typename Mma::IteratorA::Layout;
    using ElementB = typename Mma::IteratorB::Element;
    // Fixed: was `typename Mma::IteratorB::Element` (copy/paste typo binding a
    // Layout alias to the element type). run_kernel_ already derives the layout
    // from IteratorB::Layout, confirming the member exists.
    using LayoutB = typename Mma::IteratorB::Layout;
    using ElementC = typename Epilogue::OutputTileIterator::Element;
    using LayoutC = typename Mma::LayoutC;
    using ElementScale = ElementC;

    static ComplexTransform const kTransformA = Mma::kTransformA;
    // Fixed: was `Mma::kTransformA` (copy/paste typo); the B operand's transform
    // must come from kTransformB, matching the standard CUTLASS kernel templates.
    static ComplexTransform const kTransformB = Mma::kTransformB;

    // Type definitions about the mainloop.
    using Operator = typename Mma::Operator;
    using OperatorClass = typename Mma::Operator::OperatorClass;
    using ThreadblockShape = typename Mma::Shape;
    using WarpShape = typename Mma::Operator::Shape;
    using InstructionShape = typename Mma::Policy::Operator::InstructionShape;
    using ArchTag = typename Mma::ArchTag;

    static int const kStages = Mma::kStages;
    static int const kAlignmentA = Mma::IteratorA::AccessType::kElements;
    static int const kAlignmentB = Mma::IteratorB::AccessType::kElements;
    static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess;

    /// Warp count (concept: GemmShape)
    using WarpCount = typename Mma::WarpCount;
    static int const kThreadCount = 32 * WarpCount::kCount;

    // Interleave factor of the B operand's storage relative to the K dimension
    // (1 for plain row/column major; >1 for column-major-interleaved weights).
    static constexpr int kInterleave = Mma::IteratorB::Shape::kRow / Mma::Shape::kK;

    /// Host-side argument structure passed by the caller.
    struct Arguments
    {
        GemmUniversalMode mode = GemmUniversalMode::kGemm;

        cutlass::gemm::GemmCoord problem_size;
        int group_size;
        typename Mma::IteratorA::TensorRef ref_A;
        typename Mma::IteratorB::TensorRef ref_B;
        typename Mma::IteratorScale::TensorRef ref_scale;
        typename Mma::IteratorScale::TensorRef ref_zero;
        typename Epilogue::OutputTileIterator::TensorRef ref_C;
        typename Epilogue::OutputTileIterator::TensorRef ref_D;

        // Control serial split-k
        int batch_count;

        typename EpilogueOutputOp::Params output_op;

        // For gather+scatter operations
        int const* gather_A_indices;
        int const* gather_B_indices;
        int const* scatter_D_indices;

        // Included so we can use Gemm Universal
        int batch_stride_D = 0;

        //
        // Methods
        //

        CUTLASS_HOST_DEVICE
        Arguments() {}

        CUTLASS_HOST_DEVICE
        Arguments(cutlass::gemm::GemmCoord const& problem_size, int const group_size,
            typename Mma::IteratorA::TensorRef ref_A, typename Mma::IteratorB::TensorRef ref_B,
            typename Mma::IteratorScale::TensorRef ref_scale, typename Mma::IteratorScale::TensorRef ref_zero,
            typename Epilogue::OutputTileIterator::TensorRef ref_C,
            typename Epilogue::OutputTileIterator::TensorRef ref_D, int serial_split_k_factor,
            typename EpilogueOutputOp::Params output_op = typename EpilogueOutputOp::Params(),
            int const* gather_A_indices = nullptr, int const* gather_B_indices = nullptr,
            int const* scatter_D_indices = nullptr)
            : problem_size(problem_size)
            , group_size(group_size)
            , ref_A(ref_A)
            , ref_B(ref_B)
            , ref_scale(ref_scale)
            , ref_zero(ref_zero)
            , ref_C(ref_C)
            , ref_D(ref_D)
            , batch_count(serial_split_k_factor)
            , output_op(output_op)
            , gather_A_indices(gather_A_indices)
            , gather_B_indices(gather_B_indices)
            , scatter_D_indices(scatter_D_indices)
        {
        }
    };

    /// Device-side parameters structure (precomputed from Arguments on the host).
    struct Params
    {
        cutlass::gemm::GemmCoord problem_size;
        int group_size;
        cutlass::gemm::GemmCoord grid_tiled_shape;
        int swizzle_log_tile;
        typename Mma::IteratorA::Params params_A;
        typename Mma::IteratorA::TensorRef ref_A;
        typename Mma::IteratorB::Params params_B;
        typename Mma::IteratorB::TensorRef ref_B;
        typename Mma::IteratorScale::Params params_scale;
        typename Mma::IteratorScale::TensorRef ref_scale;
        typename Mma::IteratorScale::TensorRef ref_zero;
        typename Epilogue::OutputTileIterator::Params params_C;
        typename Epilogue::OutputTileIterator::TensorRef ref_C;
        typename Epilogue::OutputTileIterator::Params params_D;
        typename Epilogue::OutputTileIterator::TensorRef ref_D;
        typename EpilogueOutputOp::Params output_op;
        // Per-output-tile locks used for serial split-K ordering (in workspace).
        int* semaphore;
        // Number of K elements assigned to each split-K slice.
        int gemm_k_size;
        // For gather+scatter operations
        int const* gather_A_indices;
        int const* gather_B_indices;
        int const* scatter_D_indices;

        //
        // Methods
        //

        CUTLASS_HOST_DEVICE
        Params()
            : swizzle_log_tile(0)
            , semaphore(0)
            , gemm_k_size(0)
        {
        }

        CUTLASS_HOST_DEVICE
        Params(Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape, int const gemm_k_size,
            void* workspace = nullptr)
            : problem_size(args.problem_size)
            , group_size(args.group_size)
            , grid_tiled_shape(grid_tiled_shape)
            , swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape))
            , params_A(args.ref_A.layout())
            , ref_A(args.ref_A)
            , params_B(args.ref_B.layout())
            , ref_B(args.ref_B)
            , params_scale(args.ref_scale.layout())
            , ref_scale(args.ref_scale)
            , ref_zero(args.ref_zero)
            , params_C(args.ref_C.layout())
            , ref_C(args.ref_C)
            , params_D(args.ref_D.layout())
            , ref_D(args.ref_D)
            , output_op(args.output_op)
            , semaphore(static_cast<int*>(workspace))
            , gemm_k_size(gemm_k_size)
            , gather_A_indices(args.gather_A_indices)
            , gather_B_indices(args.gather_B_indices)
            , scatter_D_indices(args.scatter_D_indices)
        {
        }
    };

    /// Shared memory storage structure. A union is safe because the mainloop and
    /// epilogue phases never use shared memory concurrently.
    union SharedStorage
    {
        typename Mma::SharedStorage main_loop;
        typename Epilogue::SharedStorage epilogue;
    };

    //
    // Methods
    //

    CUTLASS_HOST_DEVICE
    GemmFpAIntB() {}

    /// Determines whether kernel satisfies alignment requirements for the given
    /// arguments, and validates scale/zero tensors against the quantization mode.
    static Status can_implement(Arguments const& args)
    {
        static int const kAlignmentA
            = (platform::is_same<typename Mma::IteratorA::Layout, layout::ColumnMajorInterleaved<32>>::value) ? 32
            : (platform::is_same<typename Mma::IteratorA::Layout, layout::ColumnMajorInterleaved<64>>::value)
            ? 64
            : Mma::IteratorA::AccessType::kElements;
        static int const kAlignmentB
            = (platform::is_same<typename Mma::IteratorB::Layout, layout::RowMajorInterleaved<32>>::value) ? 32
            : (platform::is_same<typename Mma::IteratorB::Layout, layout::RowMajorInterleaved<64>>::value)
            ? 64
            : Mma::IteratorB::AccessType::kElements;

        static int const kAlignmentScale = Mma::IteratorScale::AccessType::kElements;

        static int const kAlignmentC = (platform::is_same<typename Epilogue::OutputTileIterator::Layout,
                                           layout::ColumnMajorInterleaved<32>>::value)
            ? 32
            : (platform::is_same<typename Epilogue::OutputTileIterator::Layout,
                  layout::ColumnMajorInterleaved<64>>::value)
            ? 64
            : Epilogue::OutputTileIterator::kElementsPerAccess;

        if (!TensorRef_aligned(args.ref_A, kAlignmentA))
        {
            return Status::kErrorMisalignedOperand;
        }

        if (!TensorRef_aligned(args.ref_B, kAlignmentB))
        {
            return Status::kErrorMisalignedOperand;
        }

        if (!TensorRef_aligned(args.ref_scale, kAlignmentScale))
        {
            return Status::kErrorMisalignedOperand;
        }

        if (!TensorRef_aligned(args.ref_zero, kAlignmentScale))
        {
            return Status::kErrorMisalignedOperand;
        }

        if (!TensorRef_aligned(args.ref_C, kAlignmentC))
        {
            return Status::kErrorMisalignedOperand;
        }

        if (!TensorRef_aligned(args.ref_D, kAlignmentC))
        {
            return Status::kErrorMisalignedOperand;
        }

        // Scales are always required for dequantization.
        if (!args.ref_scale.good())
        {
            return Status::kErrorNotSupported;
        }

        // Zeros must be present exactly when the quant op uses them.
        if constexpr (hasZero(Mma::QuantOp))
        {
            if (!args.ref_zero.good())
            {
                return Status::kErrorNotSupported;
            }
        }
        else
        {
            if (args.ref_zero.good())
            {
                return Status::kErrorNotSupported;
            }
        }

        if constexpr (isFinegrained(Mma::QuantOp))
        {
            if (args.group_size != 64 && args.group_size != 128)
            {
                return Status::kErrorNotSupported;
            }
        }

        return Status::kSuccess;
    }

    /// No extra workspace beyond the split-K semaphore is required.
    static size_t get_extra_workspace_size(Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape)
    {

        return 0;
    }

    // Initializes the fine grained scale+bias iterator. Needed since the fine grained iterator
    // has a different constructor signature than a regular cutlass iterator
    template <typename IteratorScale, WeightOnlyQuantOp op, std::enable_if_t<isFinegrained(op), bool> = true>
    CUTLASS_DEVICE static IteratorScale initialize_scale(typename IteratorScale::Params const& params,
        typename IteratorScale::Pointer pointer_scale, typename IteratorScale::Pointer pointer_zero,
        typename IteratorScale::TensorCoord extent, int thread_id,
        typename IteratorScale::TensorCoord const& threadblock_offset, int group_size)
    {

        return IteratorScale(params, pointer_scale, pointer_zero, extent, thread_id, threadblock_offset, group_size);
    }

    // Non-fine-grained overload: the iterator takes neither zeros nor group size.
    template <typename IteratorScale, WeightOnlyQuantOp op, std::enable_if_t<!isFinegrained(op), bool> = true>
    CUTLASS_DEVICE static IteratorScale initialize_scale(typename IteratorScale::Params const& params,
        typename IteratorScale::Pointer pointer_scale, typename IteratorScale::Pointer pointer_zero,
        typename IteratorScale::TensorCoord extent, int thread_id,
        typename IteratorScale::TensorCoord const& threadblock_offset, int group_size)
    {

        return IteratorScale(params, pointer_scale, extent, thread_id, threadblock_offset);
    }

    /// Body of the kernel: mainloop over the K slice assigned to this CTA,
    /// followed by the epilogue (with optional serial split-K ordering).
    CUTLASS_DEVICE
    void run_kernel_(Params const& params, SharedStorage& shared_storage)
    {
        using LayoutB = typename Mma::IteratorB::Layout;
        static_assert(platform::is_same<LayoutB, layout::RowMajor>::value && kInterleave == 1
                || platform::is_same<LayoutB, layout::ColumnMajor>::value && kInterleave >= 1,
            "B must be row major/col major OR col major interleaved.");

        // Compute threadblock location
        ThreadblockSwizzle threadblock_swizzle;

        cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);

        // Early exit if CTA is out of range
        if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m()
            || params.grid_tiled_shape.n() <= threadblock_tile_offset.n())
        {

            return;
        }

        // Compute initial location in logical coordinates
        cutlass::MatrixCoord tb_offset_A{
            threadblock_tile_offset.m() * Mma::Shape::kM,
            threadblock_tile_offset.k() * params.gemm_k_size,
        };

        cutlass::MatrixCoord tb_offset_B{threadblock_tile_offset.k() * params.gemm_k_size * kInterleave,
            threadblock_tile_offset.n() * Mma::Shape::kN / kInterleave};

        // The scale row stride divisor is 64 — presumably the smallest supported
        // group size (can_implement only accepts 64 or 128) — TODO confirm.
        typename MatrixCoord::Index fg_row_offset = threadblock_tile_offset.k() * params.gemm_k_size / 64;
        typename MatrixCoord::Index scale_row_offset = isFinegrained(Mma::QuantOp) ? fg_row_offset : 0;
        cutlass::MatrixCoord tb_offset_scale{scale_row_offset, threadblock_tile_offset.n() * Mma::Shape::kN};

        // Problem size is a function of threadblock index in the K dimension
        int problem_size_k = min(params.problem_size.k(), (threadblock_tile_offset.k() + 1) * params.gemm_k_size);

        // Compute threadblock-scoped matrix multiply-add
        int gemm_k_iterations = (problem_size_k - tb_offset_A.column() + Mma::Shape::kK - 1) / Mma::Shape::kK;

        // Compute position within threadblock
        int thread_idx = threadIdx.x;

        // Construct iterators to A and B operands
        typename Mma::IteratorA iterator_A(params.params_A, params.ref_A.data(),
            {params.problem_size.m(), problem_size_k}, thread_idx, tb_offset_A, params.gather_A_indices);

        typename Mma::IteratorB iterator_B(params.params_B, params.ref_B.data(),
            {problem_size_k * kInterleave, params.problem_size.n() / kInterleave}, thread_idx, tb_offset_B,
            params.gather_B_indices);

        typename MatrixCoord::Index scale_row_extent = isFinegrained(Mma::QuantOp) ? problem_size_k / 64 : 1;
        typename Mma::IteratorScale iterator_scale = initialize_scale<typename Mma::IteratorScale, Mma::QuantOp>(
            params.params_scale, params.ref_scale.data(), params.ref_zero.data(),
            {scale_row_extent, params.problem_size.n()}, thread_idx, tb_offset_scale, params.group_size);

        // Broadcast the warp_id computed by lane 0 to ensure dependent code
        // is compiled as warp-uniform.
        int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
        int lane_idx = threadIdx.x % 32;

        //
        // Main loop
        //
        // Construct thread-scoped matrix multiply
        Mma mma(shared_storage.main_loop, params.group_size, thread_idx, warp_idx, lane_idx);

        typename Mma::FragmentC accumulators;

        accumulators.clear();

        if (!kSplitKSerial || gemm_k_iterations > 0)
        {
            // Compute threadblock-scoped matrix multiply-add
            mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, iterator_scale, accumulators);
        }

        //
        // Epilogue
        //

        EpilogueOutputOp output_op(params.output_op);

        //
        // Masked tile iterators constructed from members
        //

        threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);

        // assume identity swizzle
        MatrixCoord threadblock_offset(
            threadblock_tile_offset.m() * Mma::Shape::kM, threadblock_tile_offset.n() * Mma::Shape::kN);

        int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m();

        // Construct the semaphore.
        Semaphore semaphore(params.semaphore + block_idx, thread_idx);

        // If performing a reduction via split-K, fetch the initial synchronization
        if (kSplitKSerial && params.grid_tiled_shape.k() > 1)
        {

            // Fetch the synchronization lock initially but do not block.
            semaphore.fetch();

            // Indicate which position in a serial reduction the output operator is currently updating
            output_op.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k());
        }

        // Tile iterator loading from source tensor.
        typename Epilogue::OutputTileIterator iterator_C(params.params_C, params.ref_C.data(), params.problem_size.mn(),
            thread_idx, threadblock_offset, params.scatter_D_indices);

        // Tile iterator writing to destination tensor.
        typename Epilogue::OutputTileIterator iterator_D(params.params_D, params.ref_D.data(), params.problem_size.mn(),
            thread_idx, threadblock_offset, params.scatter_D_indices);

        Epilogue epilogue(shared_storage.epilogue, thread_idx, warp_idx, lane_idx);

        // Wait on the semaphore - this latency may have been covered by iterator construction
        if (kSplitKSerial && params.grid_tiled_shape.k() > 1)
        {

            // For subsequent threadblocks, the source matrix is held in the 'D' tensor.
            if (threadblock_tile_offset.k())
            {
                iterator_C = iterator_D;
            }

            semaphore.wait(threadblock_tile_offset.k());
        }

        // Execute the epilogue operator to update the destination tensor.
        epilogue(output_op, iterator_D, accumulators, iterator_C);

        //
        // Release the semaphore
        //

        if (kSplitKSerial && params.grid_tiled_shape.k() > 1)
        {

            int lock = 0;
            if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1)
            {

                // The final threadblock resets the semaphore for subsequent grids.
                lock = 0;
            }
            else
            {
                // Otherwise, the semaphore is incremented
                lock = threadblock_tile_offset.k() + 1;
            }

            semaphore.release(lock);
        }
    }

    /// Dispatch guard: only emit device code when the compilation arch matches
    /// the arch this kernel template was built for (keeps compile times down).
    template <typename CompilationArch>
    CUTLASS_DEVICE void run_kernel(Params const& params, SharedStorage& shared_storage)
    {
        if constexpr (platform::is_same<KernelArch, CompilationArch>::value)
        {
            run_kernel_(params, shared_storage);
        }
        else
        {
            CUTLASS_NOT_IMPLEMENTED();
        }
    }

    /*
        To improve compilation speed, we do not compile the device operator if the CUDA_ARCH does not correspond
        to the ArchTag of the cutlass kernel operator.
      */
    /// Executes one GEMM
    CUTLASS_DEVICE
    void operator()(Params const& params, SharedStorage& shared_storage)
    {
#if defined(__CUDA_ARCH__)
#if (__CUDA_ARCH__ >= 750) && (__CUDA_ARCH__ < 800)
        run_kernel<arch::Sm75>(params, shared_storage);
#elif (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ < 890)
        run_kernel<arch::Sm80>(params, shared_storage);
#elif (__CUDA_ARCH__ == 890)
        run_kernel<arch::Sm89>(params, shared_storage);
#elif (__CUDA_ARCH__ >= 1000)
        // Use SM80 implementation for GB10x, GB20x.
        run_kernel<arch::Sm80>(params, shared_storage);
#else
        CUTLASS_NOT_IMPLEMENTED(); // Don't compile these for Hopper or later. Use CUTLASS 3.x kernels.
#endif
#else
        CUTLASS_NOT_IMPLEMENTED();
#endif
    }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace kernel
} // namespace gemm
} // namespace cutlass
